import nltk
from nltk.corpus import words
from nltk import ngrams

# Reference word list used to judge whether a token is a real English word:
# every entry of NLTK's "words" corpus, held in a set for O(1) membership
# tests. Lookups are exact string matches (no case folding or punctuation
# stripping). Requires the NLTK "words" corpus to be installed locally.
vocabulary = {entry for entry in words.words()}

def calc_score(seq, vocab=None):
    """Return the fraction of whitespace-separated tokens in *seq* that are known words.

    Scoring rules:
      * If the substring "sorry" occurs anywhere in *seq* (case-sensitive),
        the response is treated as a refusal and scores 0.
      * Single-character tokens count toward the total but never as correct,
        so stray-character output lowers the score even when the character
        (e.g. "a") is itself in the vocabulary.
      * Every other token is correct only if it appears verbatim in the
        vocabulary (no lowercasing or punctuation stripping).

    Args:
        seq: The response text to score.
        vocab: Collection of accepted words. Defaults to the module-level
            ``vocabulary`` set.

    Returns:
        ``correct / total`` as a float in [0, 1], or 0 for a refusal or for
        text that contains no tokens at all.
    """
    if vocab is None:
        vocab = vocabulary
    if "sorry" in seq:
        return 0
    tokens = seq.split()
    if not tokens:
        # Empty or whitespace-only text has nothing to score; avoid dividing by zero.
        return 0
    # Single-character tokens are treated as degenerate generations: they are
    # part of the denominator but never counted as correct.
    correct = sum(1 for token in tokens if len(token) > 1 and token in vocab)
    return correct / len(tokens)
def _load_responses(path):
    """Read *path* and return the assistant reply extracted from each line.

    Each line is evaluated as a Python expression whose first element is the
    full generated text; only the text after the last "<|assistant|>" marker
    is kept (the whole text if the marker is absent).
    """
    with open(path, "r") as f:
        # SECURITY NOTE(review): eval() executes arbitrary code read from the
        # file. Only run this on generation dumps you produced yourself; if
        # each line is a plain repr() of a list/tuple, ast.literal_eval is a
        # safe drop-in replacement.
        return [eval(line)[0].split("<|assistant|>")[-1] for line in f]


def _load_prompts(path):
    """Return the lines of *path* with trailing whitespace removed."""
    with open(path, "r") as f:
        return [line.rstrip() for line in f]


def _score_responses(responses):
    """Return calc_score for every response that contains non-whitespace text."""
    # Whitespace-only responses have no tokens to score, so skip them just
    # like empty ones.
    return [calc_score(response) for response in responses if response.strip()]


def _mean(values):
    """Return the arithmetic mean of *values*, or NaN when there are none."""
    return sum(values) / len(values) if values else float("nan")


def _repeats_prompt(prompt, response, n=5, threshold=0.2):
    """Return True if more than *threshold* of the response's distinct n-grams occur in the prompt.

    A response too short to form a single n-gram is never counted.
    """
    response_ngrams = set(ngrams(response.split(), n))
    if not response_ngrams:
        return False
    prompt_ngrams = set(ngrams(prompt.split(), n))
    overlap = len(prompt_ngrams & response_ngrams)
    return overlap / len(response_ngrams) > threshold


normallines = _load_responses("normal_olmo_raw.txt")
perturbedlines = _load_responses("perturbed_olmo_raw.txt")
altnormallines = _load_responses("normal_olmo_alt.txt")
altperturbedlines = _load_responses("perturbed_olmo_alt.txt")
prompts = _load_prompts("prompts.txt")

normalcount = _score_responses(normallines)
perturbedcount = _score_responses(perturbedlines)
altnormalcount = _score_responses(altnormallines)
altperturbedcount = _score_responses(altperturbedlines)

print("Raw Normal:", _mean(normalcount))
print("Raw Perturbed:", _mean(perturbedcount))
print("Alt Normal:", _mean(altnormalcount))
print("Alt Perturbed:", _mean(altperturbedcount))

# Prompts and perturbed generations are paired by position, so every prompt
# needs a matching generation.
if len(prompts) > len(perturbedlines):
    raise ValueError(
        f"prompts.txt has {len(prompts)} prompts but only "
        f"{len(perturbedlines)} perturbed generations"
    )

# Count perturbed generations that largely copy their prompt (more than 20%
# of their distinct 5-grams also appear in the prompt).
repetitioncount = sum(
    1
    for prompt, perturbedline in zip(prompts, perturbedlines)
    if _repeats_prompt(prompt, perturbedline)
)
print("Raw Repetition Count:", repetitioncount)